{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46ccab68",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "#Global configuration\n",
    "repo_id = 'lansinuote/diffusion.2.textual_inversion'\n",
    "push_to_hub = True\n",
    "checkpoint = 'runwayml/stable-diffusion-v1-5'\n",
    "\n",
    "#The access token is only used by the upload steps, so the token file is\n",
    "#read only when push_to_hub is set; with push_to_hub = False the file is\n",
    "#not required and hub_token stays None.\n",
    "hub_token = None\n",
    "if push_to_hub:\n",
    "    #the context manager closes the file handle as soon as it has been read\n",
    "    with open('/root/hub_token.txt') as token_file:\n",
    "        hub_token = token_file.read().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ee4074ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DDPMScheduler {\n",
       "   \"_class_name\": \"DDPMScheduler\",\n",
       "   \"_diffusers_version\": \"0.12.1\",\n",
       "   \"beta_end\": 0.012,\n",
       "   \"beta_schedule\": \"scaled_linear\",\n",
       "   \"beta_start\": 0.00085,\n",
       "   \"clip_sample\": false,\n",
       "   \"num_train_timesteps\": 1000,\n",
       "   \"prediction_type\": \"epsilon\",\n",
       "   \"set_alpha_to_one\": false,\n",
       "   \"skip_prk_steps\": true,\n",
       "   \"steps_offset\": 1,\n",
       "   \"trained_betas\": null,\n",
       "   \"variance_type\": \"fixed_small\"\n",
       " },\n",
       " CLIPTokenizer(name_or_path='runwayml/stable-diffusion-v1-5', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken(\"<|startoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>'}))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import DDPMScheduler\n",
    "from transformers import CLIPTokenizer\n",
    "\n",
    "#Load the noise scheduler and the CLIP tokenizer from the 'scheduler' and\n",
    "#'tokenizer' subfolders of the checkpoint\n",
    "scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')\n",
    "tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')\n",
    "\n",
    "scheduler, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "385799b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Pushing split train to the Hub.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ad9beac8f3e45d99e454794b12bdd55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e1094f808c84f8ea82d122064239279",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "332612507b944868883d451f6ba9f26f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f571a25a994948558ceba0ae781b121a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/1.66M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e34743bf4b294c1d830e99db2960ba73",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/6 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.2.textual_inversion-3f03ac153edd3f09/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['image'],\n",
       "     num_rows: 6\n",
       " }),\n",
       " {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1167x1010>})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import Dataset, DatasetDict, load_dataset\n",
    "import PIL.Image\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def get_dataset():\n",
    "    \"\"\"Return a DatasetDict whose 'train' split holds images/0.jpeg .. images/5.jpeg.\n",
    "\n",
    "    Each record is a dict with a single 'image' key holding the opened PIL image.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for index in range(6):\n",
    "        picture = PIL.Image.open('images/%d.jpeg' % index)\n",
    "        records.append({'image': picture})\n",
    "\n",
    "    return DatasetDict({'train': Dataset.from_list(records)})\n",
    "\n",
    "\n",
    "#Optionally rebuild the dataset from the local images and upload it to the Hub\n",
    "if push_to_hub:\n",
    "    dataset = get_dataset()\n",
    "    dataset.push_to_hub(repo_id=repo_id, token=hub_token)\n",
    "\n",
    "#Use the already prepared dataset from the Hub directly\n",
    "dataset = load_dataset(path=repo_id, split='train')\n",
    "\n",
    "dataset, dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "880422b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['image'],\n",
       "     num_rows: 6\n",
       " }),\n",
       " {'input_ids': tensor([49406,   320,   886,  1125,   539,   518,   283,  2368,   268,  5988,\n",
       "            285, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "          49407, 49407, 49407, 49407, 49407, 49407, 49407]),\n",
       "  'pixel_values': tensor([[[ 0.0196,  0.0196,  0.0196,  ..., -0.2157, -0.2314, -0.2392],\n",
       "           [ 0.0196,  0.0196,  0.0196,  ..., -0.2078, -0.2314, -0.2392],\n",
       "           [ 0.0196,  0.0196,  0.0196,  ..., -0.2078, -0.2157, -0.2235],\n",
       "           ...,\n",
       "           [ 0.5451,  0.5529,  0.5529,  ..., -0.1765, -0.2000, -0.2941],\n",
       "           [ 0.5373,  0.5451,  0.5529,  ..., -0.4039, -0.4510, -0.5608],\n",
       "           [ 0.5373,  0.5451,  0.5451,  ..., -0.6549, -0.6784, -0.7098]],\n",
       "  \n",
       "          [[ 0.0902,  0.0902,  0.0902,  ...,  0.2549,  0.2549,  0.2471],\n",
       "           [ 0.0902,  0.0902,  0.0902,  ...,  0.2549,  0.2549,  0.2471],\n",
       "           [ 0.0902,  0.0902,  0.0902,  ...,  0.2549,  0.2471,  0.2392],\n",
       "           ...,\n",
       "           [ 0.2784,  0.2863,  0.2863,  ..., -0.2000, -0.2235, -0.3176],\n",
       "           [ 0.2706,  0.2784,  0.2863,  ..., -0.4275, -0.4745, -0.5843],\n",
       "           [ 0.2706,  0.2784,  0.2784,  ..., -0.6627, -0.6941, -0.7176]],\n",
       "  \n",
       "          [[ 0.2078,  0.2078,  0.2078,  ..., -0.1059, -0.1137, -0.1216],\n",
       "           [ 0.2078,  0.2078,  0.2078,  ..., -0.1059, -0.1137, -0.1216],\n",
       "           [ 0.2078,  0.2078,  0.2078,  ..., -0.1059, -0.1137, -0.1216],\n",
       "           ...,\n",
       "           [-0.0196, -0.0118, -0.0118,  ..., -0.1294, -0.1529, -0.2471],\n",
       "           [-0.0275, -0.0196, -0.0118,  ..., -0.3569, -0.4039, -0.5137],\n",
       "           [-0.0275, -0.0196, -0.0196,  ..., -0.6000, -0.6314, -0.6549]]])})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torchvision\n",
    "import random\n",
    "\n",
    "#Caption texts; every caption contains the placeholder word <cat-toy>\n",
    "texts = [\n",
    "    'a photo of a <cat-toy>', 'a rendering of a <cat-toy>',\n",
    "    'a cropped photo of the <cat-toy>', 'the photo of a <cat-toy>',\n",
    "    'a photo of a clean <cat-toy>', 'a photo of a dirty <cat-toy>',\n",
    "    'a dark photo of the <cat-toy>', 'a photo of my <cat-toy>',\n",
    "    'a photo of the cool <cat-toy>', 'a close-up photo of a <cat-toy>',\n",
    "    'a bright photo of the <cat-toy>', 'a cropped photo of a <cat-toy>',\n",
    "    'a photo of the <cat-toy>', 'a good photo of the <cat-toy>',\n",
    "    'a photo of one <cat-toy>', 'a close-up photo of the <cat-toy>',\n",
    "    'a rendition of the <cat-toy>', 'a photo of the clean <cat-toy>',\n",
    "    'a rendition of a <cat-toy>', 'a photo of a nice <cat-toy>',\n",
    "    'a good photo of a <cat-toy>', 'a photo of the nice <cat-toy>',\n",
    "    'a photo of the small <cat-toy>', 'a photo of the weird <cat-toy>',\n",
    "    'a photo of the large <cat-toy>', 'a photo of a cool <cat-toy>',\n",
    "    'a photo of a small <cat-toy>'\n",
    "]\n",
    "\n",
    "#Data augmentation: random horizontal flip with probability 0.5\n",
    "compose = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.RandomHorizontalFlip(p=0.5),\n",
    "])\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    \"\"\"Batch transform that turns raw images into training tensors.\n",
    "\n",
    "    Args:\n",
    "        data: batch dict; only data['image'] (a sequence of PIL images) is read.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'input_ids' (one tokenized caption per image) and\n",
    "        'pixel_values' (a list with one float32 [3, 512, 512] tensor per\n",
    "        image, scaled to [-1, 1]).\n",
    "    \"\"\"\n",
    "    images = data['image']\n",
    "\n",
    "    #Encode the text\n",
    "    #One randomly chosen caption per image, so input_ids always has as many\n",
    "    #rows as pixel_values; a single caption per call only matched batches\n",
    "    #of exactly one image.\n",
    "    captions = [random.choice(texts) for _ in images]\n",
    "    input_ids = tokenizer(captions,\n",
    "                          padding='max_length',\n",
    "                          truncation=True,\n",
    "                          max_length=tokenizer.model_max_length,\n",
    "                          return_tensors='pt')['input_ids']\n",
    "\n",
    "    #Encode the images\n",
    "    pixel_values = []\n",
    "    for image in images:\n",
    "        #force 3 channels so that the permute below is always valid\n",
    "        image = image.convert('RGB')\n",
    "\n",
    "        #data augmentation\n",
    "        image = compose(image)\n",
    "\n",
    "        #resize\n",
    "        image = image.resize((512, 512), resample=PIL.Image.Resampling.BICUBIC)\n",
    "\n",
    "        #map the 0..255 pixel values to -1..1\n",
    "        image = np.array(image).astype(np.uint8)\n",
    "        image = image / 127.5 - 1.0\n",
    "        image = image.astype(np.float32)\n",
    "\n",
    "        #to tensor, channels first\n",
    "        #[512, 512, 3] -> [3, 512, 512]\n",
    "        image = torch.from_numpy(image).permute(2, 0, 1)\n",
    "\n",
    "        pixel_values.append(image)\n",
    "\n",
    "    return {'input_ids': input_ids, 'pixel_values': pixel_values}\n",
    "\n",
    "\n",
    "dataset = dataset.with_transform(f)\n",
    "\n",
    "dataset, dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "51d14644",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6,\n",
       " {'input_ids': tensor([[49406,   320,  1125,   539,   518,  2442,   283,  2368,   268,  5988,\n",
       "             285, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407]]),\n",
       "  'pixel_values': tensor([[[[ 0.5529,  0.5451,  0.5451,  ...,  0.4902,  0.4902,  0.4824],\n",
       "            [ 0.5529,  0.5529,  0.5451,  ...,  0.4745,  0.4745,  0.4902],\n",
       "            [ 0.5608,  0.5529,  0.5529,  ...,  0.4745,  0.4745,  0.4824],\n",
       "            ...,\n",
       "            [ 0.6549,  0.6627,  0.6627,  ...,  0.6863,  0.6784,  0.6784],\n",
       "            [ 0.6549,  0.6549,  0.6627,  ...,  0.6941,  0.6941,  0.6941],\n",
       "            [ 0.6706,  0.6627,  0.6627,  ...,  0.6941,  0.6941,  0.6941]],\n",
       "  \n",
       "           [[ 0.6078,  0.6000,  0.6000,  ...,  0.4588,  0.4588,  0.4510],\n",
       "            [ 0.6078,  0.6078,  0.6000,  ...,  0.4431,  0.4431,  0.4588],\n",
       "            [ 0.6157,  0.6078,  0.6078,  ...,  0.4431,  0.4431,  0.4510],\n",
       "            ...,\n",
       "            [ 0.3647,  0.3725,  0.3725,  ...,  0.3490,  0.3412,  0.3412],\n",
       "            [ 0.3647,  0.3647,  0.3725,  ...,  0.3569,  0.3569,  0.3569],\n",
       "            [ 0.3804,  0.3725,  0.3725,  ...,  0.3569,  0.3569,  0.3569]],\n",
       "  \n",
       "           [[ 0.6706,  0.6627,  0.6627,  ...,  0.4510,  0.4510,  0.4431],\n",
       "            [ 0.6706,  0.6706,  0.6627,  ...,  0.4353,  0.4353,  0.4510],\n",
       "            [ 0.6784,  0.6706,  0.6706,  ...,  0.4353,  0.4353,  0.4431],\n",
       "            ...,\n",
       "            [ 0.0118,  0.0196,  0.0196,  ..., -0.0667, -0.0745, -0.0745],\n",
       "            [ 0.0118,  0.0118,  0.0196,  ..., -0.0588, -0.0588, -0.0588],\n",
       "            [ 0.0275,  0.0196,  0.0196,  ..., -0.0588, -0.0588, -0.0588]]]])})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Wrap the dataset in a DataLoader (batch size 1, shuffling enabled)\n",
    "loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "\n",
    "#show the number of batches and one sample batch\n",
    "len(loader), next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f58e68f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder 12306.048\n",
      "vae 8365.3863\n",
      "unet 85952.0964\n"
     ]
    }
   ],
   "source": [
    "from transformers import CLIPTextModel\n",
    "from diffusers import AutoencoderKL, UNet2DConditionModel\n",
    "\n",
    "#Load the text encoder, the VAE and the UNet from the 'text_encoder', 'vae'\n",
    "#and 'unet' subfolders of the checkpoint\n",
    "encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')\n",
    "vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')\n",
    "unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')\n",
    "\n",
    "\n",
    "def print_model_size(name, model):\n",
    "    \"\"\"Print name followed by the model's total parameter count divided by 10000.\"\"\"\n",
    "    total = 0\n",
    "    for param in model.parameters():\n",
    "        total += param.numel()\n",
    "    print(name, total / 10000)\n",
    "\n",
    "\n",
    "#Report the size of each network (parameter count divided by 10000)\n",
    "print_model_size('encoder', encoder)\n",
    "print_model_size('vae', vae)\n",
    "print_model_size('unet', unet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c0fc8a01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_new_word(new_word='<cat-toy>', init_word='toy'):\n",
    "    \"\"\"Add new_word to the tokenizer and seed its embedding from init_word.\n",
    "\n",
    "    Args:\n",
    "        new_word: token string to add to the vocabulary.\n",
    "        init_word: existing word whose embedding initialises new_word; it\n",
    "            must be encoded by the tokenizer as exactly one token.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if init_word is not encoded as a single token.\n",
    "    \"\"\"\n",
    "    #Add the new word to the vocabulary\n",
    "    tokenizer.add_tokens(new_word)\n",
    "\n",
    "    #Grow the encoder's embedding layer so that it has a row for the new word\n",
    "    encoder.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    #Ids of the old and the new word.\n",
    "    #The old id is taken from encode(), i.e. the id the tokenizer really emits\n",
    "    #for the word inside a prompt; convert_tokens_to_ids('toy') looks up the\n",
    "    #raw vocabulary string instead, which is not that entry.\n",
    "    init_ids = tokenizer.encode(init_word, add_special_tokens=False)\n",
    "    if len(init_ids) != 1:\n",
    "        raise ValueError('init_word must be a single token, got %d tokens' %\n",
    "                         len(init_ids))\n",
    "    old_id = init_ids[0]\n",
    "    new_id = tokenizer.convert_tokens_to_ids(new_word)\n",
    "\n",
    "    embed = encoder.get_input_embeddings().weight.data\n",
    "\n",
    "    #Initialise the new word with the embedding of the old word\n",
    "    embed[new_id] = embed[old_id]\n",
    "\n",
    "\n",
    "init_new_word()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f64a3ff9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(77, 768)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#These two models do not get their parameters updated\n",
    "vae.requires_grad_(False)\n",
    "unet.requires_grad_(False)\n",
    "\n",
    "#Only encoder.text_model.embeddings.token_embedding stays trainable,\n",
    "#everything else in the encoder is frozen\n",
    "encoder.text_model.encoder.requires_grad_(False)\n",
    "encoder.text_model.final_layer_norm.requires_grad_(False)\n",
    "encoder.text_model.embeddings.position_embedding.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a76101f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(AdamW (\n",
       " Parameter Group 0\n",
       "     amsgrad: False\n",
       "     betas: (0.9, 0.999)\n",
       "     capturable: False\n",
       "     eps: 1e-08\n",
       "     foreach: None\n",
       "     lr: 0.0005\n",
       "     maximize: False\n",
       "     weight_decay: 0.01\n",
       " ),\n",
       " MSELoss())"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers.optimization import get_scheduler\n",
    "\n",
    "#AdamW over the parameters of the encoder's input embedding layer only\n",
    "optimizer = torch.optim.AdamW(encoder.get_input_embeddings().parameters(),\n",
    "                              lr=5e-4,\n",
    "                              betas=(0.9, 0.999),\n",
    "                              weight_decay=0.01,\n",
    "                              eps=1e-8)\n",
    "\n",
    "#Mean squared error, used by get_loss to compare the UNet output with the noise\n",
    "criterion = torch.nn.MSELoss()\n",
    "\n",
    "optimizer, criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a0ca6c9a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0532, grad_fn=<MseLossBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss(data):\n",
    "    \"\"\"Compute the noise-prediction loss for one batch.\n",
    "\n",
    "    Args:\n",
    "        data: dict with 'input_ids' (token ids) and 'pixel_values' (images),\n",
    "            both on the same device as the models.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: criterion applied to the UNet output and the noise\n",
    "        that was added to the latents.\n",
    "    \"\"\"\n",
    "    device = data['input_ids'].device\n",
    "\n",
    "    #Encode the text\n",
    "    #[b, 77] -> [b, 77, 768]\n",
    "    out_encoder = encoder(data['input_ids'])[0]\n",
    "\n",
    "    #Compute the latent feature map with the VAE\n",
    "    #[b, 3, 512, 512] -> [b, 4, 64, 64]\n",
    "    out_vae = vae.encode(data['pixel_values']).latent_dist.sample().detach()\n",
    "\n",
    "    #0.18215 = vae.config.scaling_factor\n",
    "    out_vae = out_vae * 0.18215\n",
    "\n",
    "    #Random noise\n",
    "    #[b, 4, 64, 64]\n",
    "    noise = torch.randn_like(out_vae)\n",
    "\n",
    "    #Random noise step, one per sample in the batch; the upper bound comes\n",
    "    #from the scheduler instead of a hard-coded 1000\n",
    "    noise_step = torch.randint(0,\n",
    "                               scheduler.config.num_train_timesteps,\n",
    "                               (out_vae.shape[0], ),\n",
    "                               device=device).long()\n",
    "\n",
    "    #Add the noise\n",
    "    #[b, 4, 64, 64]\n",
    "    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)\n",
    "\n",
    "    #The UNet predicts the noise from the noisy latents\n",
    "    #[b, 4, 64, 64],[b, 77, 768] -> [b, 4, 64, 64]\n",
    "    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample\n",
    "\n",
    "    return criterion(out_unet, noise)\n",
    "\n",
    "\n",
    "#Smoke test with dummy inputs: token ids that are all 1 and random pixels\n",
    "get_loss({\n",
    "    'input_ids': torch.ones(1, 77).long(),\n",
    "    'pixel_values': torch.randn(1, 3, 512, 512)\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a93aae56",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Cloning https://huggingface.co/lansinuote/diffusion.2.textual_inversion into local empty directory.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.13851985079236329\n",
      "20 3.809695530508179\n",
      "40 4.4042913969024085\n",
      "60 4.388479830056895\n",
      "80 4.4977232606033795\n",
      "100 4.548502518853638\n",
      "120 3.693540668406058\n",
      "140 4.156299739843234\n",
      "160 3.7243479269673117\n",
      "180 3.697449386701919\n",
      "200 3.546344161964953\n",
      "220 4.308455798542127\n",
      "240 3.784453941101674\n",
      "260 3.2318183617899194\n",
      "280 4.367512361612171\n",
      "300 3.0064080328447744\n",
      "320 4.068744384741876\n",
      "340 3.318315677694045\n",
      "360 4.715249826433137\n",
      "380 3.7152999051613733\n",
      "400 3.955495126952883\n",
      "420 4.009544232976623\n",
      "440 4.505315208341926\n",
      "460 4.091954343370162\n",
      "480 4.5167977513046935\n",
      "500 3.992238358419854\n",
      "520 3.93575657781912\n",
      "540 4.0084115316858515\n",
      "560 3.725144908821676\n",
      "580 4.794721863931045\n",
      "600 4.073796798766125\n",
      "620 4.574912750336807\n",
      "640 4.365607368235942\n",
      "660 4.622854062763508\n",
      "680 4.699603373650461\n",
      "700 3.90476599894464\n",
      "720 4.237309526128229\n",
      "740 4.308226598310284\n",
      "760 4.023056301870383\n",
      "780 3.4964061353821307\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc6310364e2a4e79adbd8debdcd324d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "182cee777c4145348ea7b9a91336306f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file unet/diffusion_pytorch_model.bin:   0%|          | 32.0k/3.20G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8df18ae9fbd44aaab33673f6ccd894e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file text_encoder/pytorch_model.bin:   0%|          | 32.0k/470M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f07358fa8474bf2911557a7885d563e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file safety_checker/pytorch_model.bin:   0%|          | 32.0k/1.13G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e631482a5d3d46ad922d0cf9db2b18f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file vae/diffusion_pytorch_model.bin:   0%|          | 32.0k/319M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity...\u001b[K\n",
      "remote: LFS file scan complete.\u001b[K\n",
      "To https://user:<REDACTED>@huggingface.co/lansinuote/diffusion.2.textual_inversion\n",
      "   fff28ba..d1d1389  main -> main\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "from huggingface_hub import Repository, create_repo\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"\"\"Train the '<cat-toy>' embedding and save the full pipeline to ./save.\n",
    "\n",
    "    Uses the notebook globals encoder, vae, unet, tokenizer, loader, optimizer,\n",
    "    checkpoint and get_loss. Gradients of 4 consecutive batches are accumulated\n",
    "    per optimizer step. After every step all embedding rows except the one of\n",
    "    '<cat-toy>' are restored to the values they had when training started.\n",
    "    \"\"\"\n",
    "    accumulation_steps = 4\n",
    "\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    encoder.to(device)\n",
    "    vae.to(device)\n",
    "    unet.to(device)\n",
    "    encoder.train()\n",
    "\n",
    "    #Snapshot of the embedding matrix plus a mask selecting every row except\n",
    "    #the new word. The optimizer (including its weight decay) acts on the\n",
    "    #whole matrix, so the masked rows are copied back after each step.\n",
    "    embed = encoder.get_input_embeddings().weight\n",
    "    embed_backup = embed.detach().clone()\n",
    "    keep_rows = torch.ones(embed.shape[0],\n",
    "                           dtype=torch.bool,\n",
    "                           device=embed.device)\n",
    "    keep_rows[tokenizer.convert_tokens_to_ids('<cat-toy>')] = False\n",
    "\n",
    "    def apply_update():\n",
    "        #one optimizer step, then restore the rows that must not change\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        with torch.no_grad():\n",
    "            embed[keep_rows] = embed_backup[keep_rows]\n",
    "\n",
    "    #start from clean gradients\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    loss_sum = 0\n",
    "    pending = 0\n",
    "    for epoch in range(800):\n",
    "        for data in loader:\n",
    "            data = {k: v.to(device) for k, v in data.items()}\n",
    "\n",
    "            loss = get_loss(data) / accumulation_steps\n",
    "            loss.backward()\n",
    "            pending += 1\n",
    "\n",
    "            #Accumulated update: step only once a full group of batches has\n",
    "            #contributed (stepping at global index 0 fired after one batch)\n",
    "            if pending == accumulation_steps:\n",
    "                apply_update()\n",
    "                pending = 0\n",
    "\n",
    "            loss_sum += loss.item()\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            print(epoch, loss_sum)\n",
    "            loss_sum = 0\n",
    "\n",
    "    #apply the gradients of a trailing incomplete group instead of dropping them\n",
    "    if pending:\n",
    "        apply_update()\n",
    "\n",
    "    #Save\n",
    "    StableDiffusionPipeline.from_pretrained(\n",
    "        checkpoint,\n",
    "        text_encoder=encoder,\n",
    "        vae=vae,\n",
    "        unet=unet,\n",
    "        tokenizer=tokenizer).save_pretrained('./save')\n",
    "\n",
    "\n",
    "#Training and upload only run when push_to_hub is set\n",
    "if push_to_hub:\n",
    "    #exist_ok=True tolerates an already existing repo; it is cloned into ./save\n",
    "    create_repo(repo_id, exist_ok=True, token=hub_token)\n",
    "    repo = Repository('./save', clone_from=repo_id, token=hub_token)\n",
    "    #train() saves the pipeline into ./save, which is then pushed\n",
    "    train()\n",
    "    repo.push_to_hub()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
